{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "D2H6YLCPOXvZ",
        "outputId": "8688654f-c1d1-45f6-897a-a90e90505b9c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Collecting transformers==4.48.2\n",
            "  Downloading transformers-4.48.2-py3-none-any.whl.metadata (44 kB)\n",
            "\u001b[?25l     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/44.4 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.4/44.4 kB\u001b[0m \u001b[31m3.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from transformers==4.48.2) (3.20.2)\n",
            "Requirement already satisfied: huggingface-hub<1.0,>=0.24.0 in /usr/local/lib/python3.12/dist-packages (from transformers==4.48.2) (0.36.0)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from transformers==4.48.2) (2.0.2)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from transformers==4.48.2) (25.0)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from transformers==4.48.2) (6.0.3)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from transformers==4.48.2) (2025.11.3)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from transformers==4.48.2) (2.32.4)\n",
            "Collecting tokenizers<0.22,>=0.21 (from transformers==4.48.2)\n",
            "  Downloading tokenizers-0.21.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)\n",
            "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.12/dist-packages (from transformers==4.48.2) (0.7.0)\n",
            "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.12/dist-packages (from transformers==4.48.2) (4.67.1)\n",
            "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.24.0->transformers==4.48.2) (2025.3.0)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.24.0->transformers==4.48.2) (4.15.0)\n",
            "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.24.0->transformers==4.48.2) (1.2.0)\n",
            "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->transformers==4.48.2) (3.4.4)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->transformers==4.48.2) (3.11)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->transformers==4.48.2) (2.5.0)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->transformers==4.48.2) (2026.1.4)\n",
            "Downloading transformers-4.48.2-py3-none-any.whl (9.7 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m9.7/9.7 MB\u001b[0m \u001b[31m81.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading tokenizers-0.21.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m117.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hInstalling collected packages: tokenizers, transformers\n",
            "  Attempting uninstall: tokenizers\n",
            "    Found existing installation: tokenizers 0.22.2\n",
            "    Uninstalling tokenizers-0.22.2:\n",
            "      Successfully uninstalled tokenizers-0.22.2\n",
            "  Attempting uninstall: transformers\n",
            "    Found existing installation: transformers 4.57.3\n",
            "    Uninstalling transformers-4.57.3:\n",
            "      Successfully uninstalled transformers-4.57.3\n",
            "Successfully installed tokenizers-0.21.4 transformers-4.48.2\n",
            "Collecting datasets==2.20.0\n",
            "  Downloading datasets-2.20.0-py3-none-any.whl.metadata (19 kB)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (3.20.2)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (2.0.2)\n",
            "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (18.1.0)\n",
            "Collecting pyarrow-hotfix (from datasets==2.20.0)\n",
            "  Downloading pyarrow_hotfix-0.7-py3-none-any.whl.metadata (3.6 kB)\n",
            "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (0.3.8)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (2.2.2)\n",
            "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (2.32.4)\n",
            "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (4.67.1)\n",
            "Requirement already satisfied: xxhash in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (3.6.0)\n",
            "Requirement already satisfied: multiprocess in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (0.70.16)\n",
            "Collecting fsspec<=2024.5.0,>=2023.1.0 (from fsspec[http]<=2024.5.0,>=2023.1.0->datasets==2.20.0)\n",
            "  Downloading fsspec-2024.5.0-py3-none-any.whl.metadata (11 kB)\n",
            "Requirement already satisfied: aiohttp in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (3.13.3)\n",
            "Requirement already satisfied: huggingface-hub>=0.21.2 in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (0.36.0)\n",
            "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (25.0)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (6.0.3)\n",
            "Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.20.0) (2.6.1)\n",
            "Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.20.0) (1.4.0)\n",
            "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.20.0) (25.4.0)\n",
            "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.20.0) (1.8.0)\n",
            "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.20.0) (6.7.0)\n",
            "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.20.0) (0.4.1)\n",
            "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.20.0) (1.22.0)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.21.2->datasets==2.20.0) (4.15.0)\n",
            "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.21.2->datasets==2.20.0) (1.2.0)\n",
            "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets==2.20.0) (3.4.4)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets==2.20.0) (3.11)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets==2.20.0) (2.5.0)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets==2.20.0) (2026.1.4)\n",
            "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets==2.20.0) (2.9.0.post0)\n",
            "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets==2.20.0) (2025.2)\n",
            "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets==2.20.0) (2025.3)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.8.2->pandas->datasets==2.20.0) (1.17.0)\n",
            "Downloading datasets-2.20.0-py3-none-any.whl (547 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m547.8/547.8 kB\u001b[0m \u001b[31m45.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading fsspec-2024.5.0-py3-none-any.whl (316 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m316.1/316.1 kB\u001b[0m \u001b[31m32.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading pyarrow_hotfix-0.7-py3-none-any.whl (7.9 kB)\n",
            "Installing collected packages: pyarrow-hotfix, fsspec, datasets\n",
            "  Attempting uninstall: fsspec\n",
            "    Found existing installation: fsspec 2025.3.0\n",
            "    Uninstalling fsspec-2025.3.0:\n",
            "      Successfully uninstalled fsspec-2025.3.0\n",
            "  Attempting uninstall: datasets\n",
            "    Found existing installation: datasets 4.0.0\n",
            "    Uninstalling datasets-4.0.0:\n",
            "      Successfully uninstalled datasets-4.0.0\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "gcsfs 2025.3.0 requires fsspec==2025.3.0, but you have fsspec 2024.5.0 which is incompatible.\u001b[0m\u001b[31m\n",
            "\u001b[0mSuccessfully installed datasets-2.20.0 fsspec-2024.5.0 pyarrow-hotfix-0.7\n",
            "\n",
            "Usage:   \n",
            "  pip3 install [options] <requirement specifier> [package-index-options] ...\n",
            "  pip3 install [options] -r <requirements file> [package-index-options] ...\n",
            "  pip3 install [options] [-e] <vcs project url> ...\n",
            "  pip3 install [options] [-e] <local project path> ...\n",
            "  pip3 install [options] <archive url/path> ...\n",
            "\n",
            "no such option: --uprade\n",
            "Collecting peft==0.17.1\n",
            "  Downloading peft-0.17.1-py3-none-any.whl.metadata (14 kB)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from peft==0.17.1) (2.0.2)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from peft==0.17.1) (25.0)\n",
            "Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from peft==0.17.1) (5.9.5)\n",
            "Requirement already satisfied: pyyaml in /usr/local/lib/python3.12/dist-packages (from peft==0.17.1) (6.0.3)\n",
            "Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.12/dist-packages (from peft==0.17.1) (2.9.0+cu126)\n",
            "Requirement already satisfied: transformers in /usr/local/lib/python3.12/dist-packages (from peft==0.17.1) (4.48.2)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from peft==0.17.1) (4.67.1)\n",
            "Requirement already satisfied: accelerate>=0.21.0 in /usr/local/lib/python3.12/dist-packages (from peft==0.17.1) (1.12.0)\n",
            "Requirement already satisfied: safetensors in /usr/local/lib/python3.12/dist-packages (from peft==0.17.1) (0.7.0)\n",
            "Requirement already satisfied: huggingface_hub>=0.25.0 in /usr/local/lib/python3.12/dist-packages (from peft==0.17.1) (0.36.0)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from huggingface_hub>=0.25.0->peft==0.17.1) (3.20.2)\n",
            "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.12/dist-packages (from huggingface_hub>=0.25.0->peft==0.17.1) (2024.5.0)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from huggingface_hub>=0.25.0->peft==0.17.1) (2.32.4)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.12/dist-packages (from huggingface_hub>=0.25.0->peft==0.17.1) (4.15.0)\n",
            "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface_hub>=0.25.0->peft==0.17.1) (1.2.0)\n",
            "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (75.2.0)\n",
            "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (1.14.0)\n",
            "Requirement already satisfied: networkx>=2.5.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (3.6.1)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (3.1.6)\n",
            "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (12.6.77)\n",
            "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (12.6.77)\n",
            "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (12.6.80)\n",
            "Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (9.10.2.21)\n",
            "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (12.6.4.1)\n",
            "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (11.3.0.4)\n",
            "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (10.3.7.77)\n",
            "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (11.7.1.2)\n",
            "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (12.5.4.2)\n",
            "Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (0.7.1)\n",
            "Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (2.27.5)\n",
            "Requirement already satisfied: nvidia-nvshmem-cu12==3.3.20 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (3.3.20)\n",
            "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (12.6.77)\n",
            "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (12.6.85)\n",
            "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (1.11.1.6)\n",
            "Requirement already satisfied: triton==3.5.0 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (3.5.0)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from transformers->peft==0.17.1) (2025.11.3)\n",
            "Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.12/dist-packages (from transformers->peft==0.17.1) (0.21.4)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>=1.13.0->peft==0.17.1) (1.3.0)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=1.13.0->peft==0.17.1) (3.0.3)\n",
            "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface_hub>=0.25.0->peft==0.17.1) (3.4.4)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface_hub>=0.25.0->peft==0.17.1) (3.11)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface_hub>=0.25.0->peft==0.17.1) (2.5.0)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface_hub>=0.25.0->peft==0.17.1) (2026.1.4)\n",
            "Downloading peft-0.17.1-py3-none-any.whl (504 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m504.9/504.9 kB\u001b[0m \u001b[31m39.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hInstalling collected packages: peft\n",
            "  Attempting uninstall: peft\n",
            "    Found existing installation: peft 0.18.0\n",
            "    Uninstalling peft-0.18.0:\n",
            "      Successfully uninstalled peft-0.18.0\n",
            "Successfully installed peft-0.17.1\n",
            "Downloading...\n",
            "From: https://drive.google.com/uc?id=1_DdLeKnGHZWpOhJ8kRGW1I_vBA4jzlXY\n",
            "To: /content/labelled_data.csv\n",
            "100% 5.50M/5.50M [00:00<00:00, 25.4MB/s]\n"
          ]
        }
      ],
      "source": [
        "# Pinned dependencies. %pip (not !pip) installs into the running kernel's environment.\n",
        "# NOTE: datasets==2.20.0 pulls fsspec<=2024.5.0, which pip reports as conflicting with\n",
        "# the preinstalled gcsfs; gcsfs is not imported in this notebook.\n",
        "%pip install transformers==4.48.2\n",
        "%pip install datasets==2.20.0\n",
        "%pip install --upgrade --no-cache-dir gdown==4.5.4\n",
        "%pip install peft==0.17.1\n",
        "\n",
        "# Download the labelled dataset (saved as labelled_data.csv) from Google Drive.\n",
        "!gdown 1_DdLeKnGHZWpOhJ8kRGW1I_vBA4jzlXY"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U3ahFqA9SgjS"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "\n",
        "import pandas as pd\n",
        "\n",
        "def is_positive(x):\n",
        "    \"\"\"Collapse a JSON-encoded object of numeric values into a 0/1 flag.\n",
        "\n",
        "    Decodes ``x`` with ``json.loads`` and returns 1 if the object's values sum to\n",
        "    more than 0, otherwise 0. (Presumably the values are per-annotator votes --\n",
        "    confirm against the dataset's documentation.)\n",
        "    \"\"\"\n",
        "    total = sum(json.loads(x).values())\n",
        "    return int(total > 0)\n",
        "\n",
        "data = pd.read_csv(\"labelled_data.csv\")\n",
        "\n",
        "# Derive a 0/1 flag column from each JSON-encoded label column (order preserved).\n",
        "for _src_col, _flag_col in (\n",
        "    (\"left_wing\", \"is_left_wing\"),\n",
        "    (\"right_wing\", \"is_right_wing\"),\n",
        "    (\"anti_elitism\", \"is_anti_elitism\"),\n",
        "    (\"people_centrism\", \"is_people_centrism\"),\n",
        "):\n",
        "    data[_flag_col] = data[_src_col].apply(is_positive)\n",
        "del _src_col, _flag_col  # keep loop temporaries out of the notebook namespace\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 233,
          "referenced_widgets": [
            "0b7c9143f16546ccb0114dd6f5f05e1f",
            "588d00688b8444f0a0dd4aa2ef08f84a",
            "748922c942404e878bbe74c957c2b055",
            "ca90952353d941d1a5efd123c4524fc3",
            "035d1e58ab2a4ff7bd80d74fc9581958",
            "c87babf60681452e80d6c68b0fa644c3",
            "60d65ed196de4d6782741c47e548c43d",
            "04e632bbd11b405ca6579c441538a446",
            "67d3b3e533d640beae9618ede0d92c61",
            "7e30e053f1894bd99fbfbff432ae00a7",
            "3a76ddf6f4d7436b882b4c686a338f65",
            "289357be8ae3454286754e55c8f40c3b",
            "501fc476aa9c46efa054d5357d29c254",
            "5afedaf6e44c4612baf7cb9ddb0c2d69",
            "e1aaa790ef8f4a66bdebca8daa6c029c",
            "48c0a120298947c2a6e2082d16375a0c",
            "cf9b78f2b4f64d98af4e66ad8e523757",
            "259407da05744d27ba945cab73d297c3",
            "5d715a75665846f5a174e49f4ee20029",
            "6e05804bb0694170bddb8bc075154e50",
            "ebca5cc51c674a4e8d6363d4dca3e5fb",
            "94a05c4a4a3141d4b53736736105082a",
            "6cebbc05a32843e1a0243addc246a3ea",
            "7e2e00b07f2440c797c8dd4cde418016",
            "710379c0119840b2919354ca0eb8168a",
            "721134531c9d426cad45877d7760dedd",
            "34c5abaeaca240ae9610a7c6c03e2992",
            "498223edfee444d79fbfb5729cd7b1b5",
            "4e030f30944f4da58ea23071874017af",
            "0cc219f470d847b8add4926ac2980b14",
            "6f8f8354bf9e4d438b2b145b30bc31d5",
            "239573d80f7b45469cc47f5db5e52624",
            "a3837c636acb42628a558c83a90e6c92"
          ]
        },
        "id": "9JgpFVFJhh-k",
        "outputId": "4ab96366-a558-434e-ccc8-60f3d3a1721c"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n",
            "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
            "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
            "You will be able to reuse this secret in all of your notebooks.\n",
            "Please note that authentication is recommended but still optional to access public models or datasets.\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "0b7c9143f16546ccb0114dd6f5f05e1f",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer_config.json:   0%|          | 0.00/83.0 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "289357be8ae3454286754e55c8f40c3b",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "vocab.txt: 0.00B [00:00, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "6cebbc05a32843e1a0243addc246a3ea",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "from collections import defaultdict\n",
        "import csv\n",
        "import random\n",
        "import time\n",
        "\n",
        "from datasets import load_metric\n",
        "import numpy as np\n",
        "from sklearn.model_selection import train_test_split\n",
        "import torch\n",
        "from transformers import BertTokenizerFast\n",
        "\n",
        "MODEL_NAME = 'deepset/gbert-large'  # Hugging Face Hub checkpoint for the tokenizer\n",
        "mlength = 512   # maximum sequence length in tokens (presumably used when tokenizing later)\n",
        "nclasses = 4    # number of labels; matches the four is_* flag columns\n",
        "start = time.time()  # wall-clock start time, taken before the tokenizer download\n",
        "tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)\n",
        "\n",
        "class PSCDataset(torch.utils.data.Dataset):\n",
        "    \"\"\"Map-style torch dataset pairing encoded features with label vectors.\n",
        "\n",
        "    Args:\n",
        "        encodings: mapping from feature name to a per-example indexable sequence\n",
        "            (presumably the output of the tokenizer above).\n",
        "        labels: indexable sequence of per-example labels; its length is the dataset\n",
        "            length. NOTE(review): ``labels[idx]`` is positional only for lists/arrays;\n",
        "            a pandas Series would be looked up by index label.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, encodings, labels):\n",
        "        self.encodings = encodings\n",
        "        self.labels = labels\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.labels)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        \"\"\"Return example ``idx`` as a dict of tensors, with float targets under 'labels'.\"\"\"\n",
        "        example = {}\n",
        "        for name, values in self.encodings.items():\n",
        "            example[name] = torch.tensor(values[idx])\n",
        "        # Float targets -- presumably for a multi-label (BCE-with-logits) loss; confirm\n",
        "        # against the model configuration in the training cell.\n",
        "        example['labels'] = torch.tensor(self.labels[idx], dtype=torch.float)\n",
        "        return example\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eGj5f_Tzfw0N"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "from sklearn.metrics import f1_score, accuracy_score, precision_recall_fscore_support, precision_score, recall_score, precision_recall_curve\n",
        "\n",
        "# Label names, in the column order of the logits / label matrices.\n",
        "LABEL_NAMES = ('left_wing', 'right_wing', 'anti_elitism', 'people_centrism')\n",
        "\n",
        "\n",
        "def _best_f1_point(y_true, y_score):\n",
        "    \"\"\"Find the operating point that maximises F1 for one binary label.\n",
        "\n",
        "    Args:\n",
        "        y_true: 1-D array of 0/1 ground-truth labels.\n",
        "        y_score: 1-D array of real-valued scores (e.g. logits); higher means positive.\n",
        "\n",
        "    Returns:\n",
        "        (precision, recall, f1, threshold) at the first threshold reaching the max F1.\n",
        "    \"\"\"\n",
        "    precision, recall, thresholds = precision_recall_curve(y_true, y_score)\n",
        "    denom = precision + recall\n",
        "    # F1 is 0 where precision + recall == 0 (a threshold admitting only negatives);\n",
        "    # plain division would yield 0/0 = NaN and emit a RuntimeWarning.\n",
        "    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(denom), where=denom > 0)\n",
        "    # The curve's last point (precision=1, recall=0) has no threshold and F1 == 0,\n",
        "    # so it is skipped to keep thresholds[best] in range.\n",
        "    best = int(np.argmax(f1[:-1]))\n",
        "    return precision[best], recall[best], f1[best], thresholds[best]\n",
        "\n",
        "\n",
        "def _safe_ratio(num, den):\n",
        "    \"\"\"Return num / den, or 0.0 when den is zero.\"\"\"\n",
        "    return num / den if den else 0.0\n",
        "\n",
        "\n",
        "def compute_metrics(pred):\n",
        "    \"\"\"Trainer metrics hook: best-threshold F1 per label plus their macro average.\n",
        "\n",
        "    Args:\n",
        "        pred: a (logits, labels) pair (e.g. a transformers EvalPrediction); both are\n",
        "            2-D with one column per entry of LABEL_NAMES.\n",
        "\n",
        "    Returns:\n",
        "        dict mapping '<label>_f1' to that label's maximum F1 over all thresholds,\n",
        "        plus 'macro_f1', the unweighted mean of those values. Thresholds are tuned on\n",
        "        the same data being scored, so these numbers are optimistic.\n",
        "    \"\"\"\n",
        "    logits, labels = pred\n",
        "    logits = np.asarray(logits)\n",
        "    labels = np.asarray(labels)\n",
        "    f1s = [_best_f1_point(labels[:, k], logits[:, k])[2] for k in range(len(LABEL_NAMES))]\n",
        "    metrics = {f'{name}_f1': f1 for name, f1 in zip(LABEL_NAMES, f1s)}\n",
        "    metrics['macro_f1'] = sum(f1s) / len(f1s)\n",
        "    return metrics\n",
        "\n",
        "\n",
        "def metric_helper(labels, logits):\n",
        "    \"\"\"Per-label, micro and macro precision/recall/F1 at each label's best-F1 threshold.\n",
        "\n",
        "    Args:\n",
        "        labels: array-like of 0/1 targets, one column per entry of LABEL_NAMES.\n",
        "        logits: array-like of scores with the same shape as labels.\n",
        "\n",
        "    Returns:\n",
        "        18-tuple: (precision, recall, f1) for each label in LABEL_NAMES order, then\n",
        "        micro precision, recall, F1 (counts pooled over all labels), then macro\n",
        "        precision, recall, F1 (unweighted means of the per-label values).\n",
        "    \"\"\"\n",
        "    labels = np.array(labels)\n",
        "    logits = np.array(logits)\n",
        "    per_label = [_best_f1_point(labels[:, k], logits[:, k]) for k in range(len(LABEL_NAMES))]\n",
        "    precisions, recalls, f1s, thresholds = zip(*per_label)\n",
        "\n",
        "    n_labels = len(per_label)\n",
        "    macro_precision = sum(precisions) / n_labels\n",
        "    macro_recall = sum(recalls) / n_labels\n",
        "    macro_f1 = sum(f1s) / n_labels\n",
        "\n",
        "    # Binarise each label at its own best-F1 threshold, then pool counts across labels.\n",
        "    predicted = (logits >= np.array(thresholds)).astype(int)\n",
        "    tp = np.sum((predicted == 1) & (labels == 1))\n",
        "    fp = np.sum((predicted == 1) & (labels == 0))\n",
        "    fn = np.sum((predicted == 0) & (labels == 1))\n",
        "    micro_precision = _safe_ratio(tp, tp + fp)\n",
        "    micro_recall = _safe_ratio(tp, tp + fn)\n",
        "    micro_f1 = _safe_ratio(2 * micro_precision * micro_recall, micro_precision + micro_recall)\n",
        "\n",
        "    per_label_prf = [value for p, r, f, _ in per_label for value in (p, r, f)]\n",
        "    return (*per_label_prf, micro_precision, micro_recall, micro_f1,\n",
        "            macro_precision, macro_recall, macro_f1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true,
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "8TRFGAbihF_p",
        "outputId": "9e12be2e-9902-410a-8389-9f77a5860acf"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at deepset/gbert-large and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at deepset/gbert-large and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='2451' max='2451' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [2451/2451 07:57, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Left Wing F1</th>\n",
              "      <th>Right Wing F1</th>\n",
              "      <th>Anti Elitism F1</th>\n",
              "      <th>People Centrism F1</th>\n",
              "      <th>Macro F1</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.231900</td>\n",
              "      <td>0.256091</td>\n",
              "      <td>0.690000</td>\n",
              "      <td>0.720930</td>\n",
              "      <td>0.827027</td>\n",
              "      <td>0.656250</td>\n",
              "      <td>0.723552</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.191400</td>\n",
              "      <td>0.262710</td>\n",
              "      <td>0.730769</td>\n",
              "      <td>0.750000</td>\n",
              "      <td>0.827586</td>\n",
              "      <td>0.686047</td>\n",
              "      <td>0.748600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.066700</td>\n",
              "      <td>0.289579</td>\n",
              "      <td>0.734300</td>\n",
              "      <td>0.788462</td>\n",
              "      <td>0.826531</td>\n",
              "      <td>0.666667</td>\n",
              "      <td>0.753990</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/tmp/ipython-input-3850816672.py:12: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n",
            "/tmp/ipython-input-3850816672.py:18: RuntimeWarning: invalid value encountered in divide\n",
            "  l4_f1 = 2 * l4_precision * l4_recall / (l4_precision + l4_recall)\n",
            "/tmp/ipython-input-3850816672.py:12: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n"
          ]
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/tmp/ipython-input-3850816672.py:12: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n",
            "/tmp/ipython-input-3850816672.py:41: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n"
          ]
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "======================\n",
            "0.7539895835703289 1.5e-05 11\n",
            "Metric:  Precision, Recall, F1\n",
            "Left wing: 0.70, 0.70, 0.70\n",
            "Right wing: 0.66, 0.76, 0.71\n",
            "Anti elitism: 0.78, 0.94, 0.85\n",
            "People centrism: 0.67, 0.74, 0.71\n",
            "Micro: 0.73, 0.83, 0.78\n",
            "Macro: 0.70, 0.79, 0.74\n",
            "======================\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at deepset/gbert-large and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at deepset/gbert-large and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='2451' max='2451' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [2451/2451 08:11, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Left Wing F1</th>\n",
              "      <th>Right Wing F1</th>\n",
              "      <th>Anti Elitism F1</th>\n",
              "      <th>People Centrism F1</th>\n",
              "      <th>Macro F1</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.239900</td>\n",
              "      <td>0.260514</td>\n",
              "      <td>0.688525</td>\n",
              "      <td>0.703297</td>\n",
              "      <td>0.832487</td>\n",
              "      <td>0.652406</td>\n",
              "      <td>0.719179</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.190500</td>\n",
              "      <td>0.258866</td>\n",
              "      <td>0.733333</td>\n",
              "      <td>0.720000</td>\n",
              "      <td>0.832000</td>\n",
              "      <td>0.660550</td>\n",
              "      <td>0.736471</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.077000</td>\n",
              "      <td>0.278790</td>\n",
              "      <td>0.748603</td>\n",
              "      <td>0.736842</td>\n",
              "      <td>0.826196</td>\n",
              "      <td>0.645161</td>\n",
              "      <td>0.739201</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/tmp/ipython-input-3850816672.py:12: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n",
            "/tmp/ipython-input-3850816672.py:12: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n",
            "/tmp/ipython-input-3850816672.py:12: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n"
          ]
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/tmp/ipython-input-3850816672.py:12: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n",
            "/tmp/ipython-input-3850816672.py:41: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n",
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at deepset/gbert-large and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at deepset/gbert-large and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='2451' max='2451' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [2451/2451 08:02, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Left Wing F1</th>\n",
              "      <th>Right Wing F1</th>\n",
              "      <th>Anti Elitism F1</th>\n",
              "      <th>People Centrism F1</th>\n",
              "      <th>Macro F1</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.243000</td>\n",
              "      <td>0.265153</td>\n",
              "      <td>0.662983</td>\n",
              "      <td>0.666667</td>\n",
              "      <td>0.841026</td>\n",
              "      <td>0.666667</td>\n",
              "      <td>0.709336</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.209500</td>\n",
              "      <td>0.253353</td>\n",
              "      <td>0.710660</td>\n",
              "      <td>0.704762</td>\n",
              "      <td>0.825065</td>\n",
              "      <td>0.677966</td>\n",
              "      <td>0.729613</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.118800</td>\n",
              "      <td>0.259870</td>\n",
              "      <td>0.723404</td>\n",
              "      <td>0.725275</td>\n",
              "      <td>0.828338</td>\n",
              "      <td>0.648889</td>\n",
              "      <td>0.731476</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/tmp/ipython-input-3850816672.py:12: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n",
            "/tmp/ipython-input-3850816672.py:12: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n",
            "/tmp/ipython-input-3850816672.py:12: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n"
          ]
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/tmp/ipython-input-3850816672.py:12: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n",
            "/tmp/ipython-input-3850816672.py:41: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n",
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at deepset/gbert-large and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at deepset/gbert-large and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='2451' max='2451' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [2451/2451 08:04, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Left Wing F1</th>\n",
              "      <th>Right Wing F1</th>\n",
              "      <th>Anti Elitism F1</th>\n",
              "      <th>People Centrism F1</th>\n",
              "      <th>Macro F1</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.254200</td>\n",
              "      <td>0.253160</td>\n",
              "      <td>0.701299</td>\n",
              "      <td>0.745455</td>\n",
              "      <td>0.841076</td>\n",
              "      <td>0.718615</td>\n",
              "      <td>0.751611</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.225500</td>\n",
              "      <td>0.238757</td>\n",
              "      <td>0.742138</td>\n",
              "      <td>0.795918</td>\n",
              "      <td>0.843243</td>\n",
              "      <td>0.716814</td>\n",
              "      <td>0.774529</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.067000</td>\n",
              "      <td>0.272172</td>\n",
              "      <td>0.763158</td>\n",
              "      <td>0.819048</td>\n",
              "      <td>0.838046</td>\n",
              "      <td>0.738589</td>\n",
              "      <td>0.789710</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "======================\n",
            "0.7897102494740729 1.5e-05 12\n",
            "Metric:  Precision, Recall, F1\n",
            "Left wing: 0.75, 0.75, 0.75\n",
            "Right wing: 0.69, 0.79, 0.74\n",
            "Anti elitism: 0.86, 0.84, 0.85\n",
            "People centrism: 0.65, 0.82, 0.72\n",
            "Micro: 0.76, 0.81, 0.79\n",
            "Macro: 0.74, 0.80, 0.76\n",
            "======================\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at deepset/gbert-large and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at deepset/gbert-large and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='2451' max='2451' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [2451/2451 07:59, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Left Wing F1</th>\n",
              "      <th>Right Wing F1</th>\n",
              "      <th>Anti Elitism F1</th>\n",
              "      <th>People Centrism F1</th>\n",
              "      <th>Macro F1</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.245700</td>\n",
              "      <td>0.254602</td>\n",
              "      <td>0.714286</td>\n",
              "      <td>0.703704</td>\n",
              "      <td>0.844920</td>\n",
              "      <td>0.729614</td>\n",
              "      <td>0.748131</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.201700</td>\n",
              "      <td>0.242625</td>\n",
              "      <td>0.783133</td>\n",
              "      <td>0.776699</td>\n",
              "      <td>0.828283</td>\n",
              "      <td>0.731092</td>\n",
              "      <td>0.779802</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.082400</td>\n",
              "      <td>0.261075</td>\n",
              "      <td>0.746667</td>\n",
              "      <td>0.788991</td>\n",
              "      <td>0.837333</td>\n",
              "      <td>0.718447</td>\n",
              "      <td>0.772859</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at deepset/gbert-large and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at deepset/gbert-large and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='2451' max='2451' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [2451/2451 07:53, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Left Wing F1</th>\n",
              "      <th>Right Wing F1</th>\n",
              "      <th>Anti Elitism F1</th>\n",
              "      <th>People Centrism F1</th>\n",
              "      <th>Macro F1</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.234900</td>\n",
              "      <td>0.264393</td>\n",
              "      <td>0.677966</td>\n",
              "      <td>0.690909</td>\n",
              "      <td>0.840206</td>\n",
              "      <td>0.722467</td>\n",
              "      <td>0.732887</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.264800</td>\n",
              "      <td>0.248827</td>\n",
              "      <td>0.738095</td>\n",
              "      <td>0.701031</td>\n",
              "      <td>0.836461</td>\n",
              "      <td>0.732673</td>\n",
              "      <td>0.752065</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.145600</td>\n",
              "      <td>0.248549</td>\n",
              "      <td>0.714286</td>\n",
              "      <td>0.700855</td>\n",
              "      <td>0.840426</td>\n",
              "      <td>0.730594</td>\n",
              "      <td>0.746540</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at deepset/gbert-large and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at deepset/gbert-large and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='2451' max='2451' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [2451/2451 08:00, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Left Wing F1</th>\n",
              "      <th>Right Wing F1</th>\n",
              "      <th>Anti Elitism F1</th>\n",
              "      <th>People Centrism F1</th>\n",
              "      <th>Macro F1</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.248400</td>\n",
              "      <td>0.235200</td>\n",
              "      <td>0.719512</td>\n",
              "      <td>0.632653</td>\n",
              "      <td>0.869565</td>\n",
              "      <td>0.623116</td>\n",
              "      <td>0.711212</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.098300</td>\n",
              "      <td>0.251969</td>\n",
              "      <td>0.726115</td>\n",
              "      <td>0.689655</td>\n",
              "      <td>0.868217</td>\n",
              "      <td>0.677419</td>\n",
              "      <td>0.740352</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.076600</td>\n",
              "      <td>0.271290</td>\n",
              "      <td>0.748466</td>\n",
              "      <td>0.712329</td>\n",
              "      <td>0.870712</td>\n",
              "      <td>0.681319</td>\n",
              "      <td>0.753207</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/tmp/ipython-input-3850816672.py:15: RuntimeWarning: invalid value encountered in divide\n",
            "  l3_f1 = 2 * l3_precision * l3_recall / (l3_precision + l3_recall)\n",
            "/tmp/ipython-input-3850816672.py:9: RuntimeWarning: invalid value encountered in divide\n",
            "  l1_f1 = 2 * l1_precision * l1_recall / (l1_precision + l1_recall)\n",
            "/tmp/ipython-input-3850816672.py:12: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n"
          ]
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/tmp/ipython-input-3850816672.py:12: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n",
            "/tmp/ipython-input-3850816672.py:41: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n"
          ]
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "======================\n",
            "0.7532065267915224 1.5e-05 13\n",
            "Metric:  Precision, Recall, F1\n",
            "Left wing: 0.74, 0.70, 0.72\n",
            "Right wing: 0.70, 0.64, 0.67\n",
            "Anti elitism: 0.82, 0.89, 0.86\n",
            "People centrism: 0.65, 0.78, 0.71\n",
            "Micro: 0.75, 0.80, 0.78\n",
            "Macro: 0.73, 0.75, 0.74\n",
            "======================\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at deepset/gbert-large and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at deepset/gbert-large and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='2451' max='2451' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [2451/2451 08:05, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Left Wing F1</th>\n",
              "      <th>Right Wing F1</th>\n",
              "      <th>Anti Elitism F1</th>\n",
              "      <th>People Centrism F1</th>\n",
              "      <th>Macro F1</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.259100</td>\n",
              "      <td>0.236966</td>\n",
              "      <td>0.742138</td>\n",
              "      <td>0.659574</td>\n",
              "      <td>0.855721</td>\n",
              "      <td>0.627273</td>\n",
              "      <td>0.721177</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.119600</td>\n",
              "      <td>0.246814</td>\n",
              "      <td>0.727273</td>\n",
              "      <td>0.708861</td>\n",
              "      <td>0.857143</td>\n",
              "      <td>0.674556</td>\n",
              "      <td>0.741958</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.092600</td>\n",
              "      <td>0.262162</td>\n",
              "      <td>0.741573</td>\n",
              "      <td>0.659091</td>\n",
              "      <td>0.858667</td>\n",
              "      <td>0.678788</td>\n",
              "      <td>0.734530</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/tmp/ipython-input-3850816672.py:12: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n",
            "/tmp/ipython-input-3850816672.py:12: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n",
            "/tmp/ipython-input-3850816672.py:12: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n"
          ]
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/tmp/ipython-input-3850816672.py:12: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n",
            "/tmp/ipython-input-3850816672.py:41: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n",
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at deepset/gbert-large and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at deepset/gbert-large and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='2451' max='2451' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [2451/2451 08:06, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Left Wing F1</th>\n",
              "      <th>Right Wing F1</th>\n",
              "      <th>Anti Elitism F1</th>\n",
              "      <th>People Centrism F1</th>\n",
              "      <th>Macro F1</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.250200</td>\n",
              "      <td>0.253269</td>\n",
              "      <td>0.740260</td>\n",
              "      <td>0.565657</td>\n",
              "      <td>0.872340</td>\n",
              "      <td>0.642857</td>\n",
              "      <td>0.705278</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.149800</td>\n",
              "      <td>0.240159</td>\n",
              "      <td>0.736842</td>\n",
              "      <td>0.659574</td>\n",
              "      <td>0.871391</td>\n",
              "      <td>0.651429</td>\n",
              "      <td>0.729809</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.151300</td>\n",
              "      <td>0.241813</td>\n",
              "      <td>0.754286</td>\n",
              "      <td>0.681818</td>\n",
              "      <td>0.878307</td>\n",
              "      <td>0.648649</td>\n",
              "      <td>0.740765</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/tmp/ipython-input-3850816672.py:12: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n",
            "/tmp/ipython-input-3850816672.py:12: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n",
            "/tmp/ipython-input-3850816672.py:12: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n"
          ]
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/tmp/ipython-input-3850816672.py:12: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n",
            "/tmp/ipython-input-3850816672.py:41: RuntimeWarning: invalid value encountered in divide\n",
            "  l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n"
          ]
        }
      ],
      "source": [
        "# For each seed: re-split the data, grid-search the learning rate on the dev split,\n",
        "# and keep the test metrics of the run with the best dev macro F1.\n",
        "import gc\n",
        "from transformers import BertForSequenceClassification, TrainingArguments, Trainer\n",
        "\n",
        "seeds = range(11,14)\n",
        "epochs = 3\n",
        "learning_rates = [15e-6, 10e-6, 5e-6]\n",
        "n_test = 1759  # 1759 (20%) held out for testing\n",
        "n_dev = 500    # dev split, used for checkpoint and learning-rate selection\n",
        "label_columns = [\"is_left_wing\", \"is_right_wing\", \"is_anti_elitism\", \"is_people_centrism\"]\n",
        "\n",
        "def texts_and_labels(frame):\n",
        "  \"\"\"Return (texts, labels) for a split; each label is a tuple ordered like label_columns.\"\"\"\n",
        "  labels = list(zip(*(frame[column].tolist() for column in label_columns)))\n",
        "  return frame[\"text\"].tolist(), labels\n",
        "\n",
        "records = {}  # seed -> np.array of the metric_helper test metrics of the selected run\n",
        "for seed in seeds:\n",
        "  np.random.seed(seed)\n",
        "  torch.manual_seed(seed)\n",
        "  random.seed(seed)\n",
        "\n",
        "  # first n_test rows for testing, the next n_dev for dev, the rest for training\n",
        "  shuffled_data = data.sample(random_state = seed, frac = 1)\n",
        "  test_data = shuffled_data[:n_test]\n",
        "  dev_data = shuffled_data[n_test:(n_test + n_dev)]\n",
        "  train_data = shuffled_data[(n_test + n_dev):]\n",
        "\n",
        "  test_texts, test_labels = texts_and_labels(test_data)\n",
        "  train_texts, train_labels = texts_and_labels(train_data)\n",
        "  dev_texts, dev_labels = texts_and_labels(dev_data)\n",
        "\n",
        "  # extra shuffle of the training order, driven by Python's random state seeded above\n",
        "  zipped = list(zip(train_texts, train_labels))\n",
        "  random.shuffle(zipped)\n",
        "  train_texts, train_labels = zip(*zipped)\n",
        "\n",
        "  train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=mlength)\n",
        "  dev_encodings = tokenizer(dev_texts, truncation=True, padding=True, max_length=mlength)\n",
        "  test_encodings = tokenizer(test_texts, truncation=True, padding=True, max_length=mlength)\n",
        "\n",
        "  train_dataset = PSCDataset(train_encodings, train_labels)\n",
        "  dev_dataset = PSCDataset(dev_encodings, dev_labels)\n",
        "  test_dataset = PSCDataset(test_encodings, test_labels)\n",
        "\n",
        "  # Start below any reachable score so the first run with a finite dev F1 is always kept.\n",
        "  max_dev_f1 = -np.inf\n",
        "  for learning_rate in learning_rates:\n",
        "    training_args = TrainingArguments(\n",
        "        output_dir=\"./results\",          # output directory\n",
        "        num_train_epochs=epochs,         # total number of training epochs\n",
        "        per_device_train_batch_size=8,   # batch size per device during training\n",
        "        per_device_eval_batch_size=64,   # batch size for evaluation\n",
        "        warmup_steps=0,                  # number of warmup steps for learning rate scheduler\n",
        "        weight_decay=0.01,               # strength of weight decay\n",
        "        logging_dir='./logs',            # directory for storing logs\n",
        "        logging_steps=10,\n",
        "        learning_rate = learning_rate,\n",
        "        save_strategy= \"epoch\",\n",
        "        eval_strategy=\"epoch\",\n",
        "        load_best_model_at_end= True,\n",
        "        metric_for_best_model=\"macro_f1\",\n",
        "        save_total_limit = 2,\n",
        "        seed = seed,\n",
        "        report_to = \"none\",\n",
        "    )\n",
        "\n",
        "    def model_init():\n",
        "        return BertForSequenceClassification.from_pretrained(\"deepset/gbert-large\", num_labels=nclasses, problem_type=\"multi_label_classification\")\n",
        "\n",
        "    trainer = Trainer(\n",
        "        model_init=model_init,               # fresh model for every run\n",
        "        args=training_args,                  # training arguments, defined above\n",
        "        train_dataset=train_dataset,         # training dataset\n",
        "        eval_dataset=dev_dataset,            # evaluation dataset\n",
        "        compute_metrics=compute_metrics,     # compute_metrics\n",
        "        )\n",
        "\n",
        "    trainer.train()\n",
        "    # dev macro F1 (last value of metric_helper) of the checkpoint kept by load_best_model_at_end\n",
        "    predictions = trainer.predict(dev_dataset)\n",
        "    logits = predictions.predictions\n",
        "    labels = dev_dataset.labels\n",
        "    dev_macro_f1 = metric_helper(labels, logits)[-1]\n",
        "\n",
        "    # NaN compares False with everything; check it explicitly so a NaN run is never selected.\n",
        "    if np.isfinite(dev_macro_f1) and dev_macro_f1 > max_dev_f1: # re-score the test split only when dev macro F1 improves\n",
        "      max_dev_f1 = dev_macro_f1\n",
        "      predictions = trainer.predict(test_dataset)\n",
        "      logits = predictions.predictions\n",
        "      labels = test_dataset.labels\n",
        "      max_l1_precision, max_l1_recall, max_l1_f1, max_l2_precision, max_l2_recall, max_l2_f1, max_l3_precision, max_l3_recall, max_l3_f1, max_l4_precision, max_l4_recall, max_l4_f1, micro_precision, micro_recall, micro_f1, macro_precision, macro_recall, macro_f1 = metric_helper(labels, logits)\n",
        "      print(\"======================\")\n",
        "      print(max_dev_f1, learning_rate, seed)\n",
        "      print(\"Metric:  Precision, Recall, F1\")\n",
        "      print(f\"Left wing: {max_l1_precision:.2f}, {max_l1_recall:.2f}, {max_l1_f1:.2f}\")\n",
        "      print(f\"Right wing: {max_l2_precision:.2f}, {max_l2_recall:.2f}, {max_l2_f1:.2f}\")\n",
        "      print(f\"Anti elitism: {max_l3_precision:.2f}, {max_l3_recall:.2f}, {max_l3_f1:.2f}\")\n",
        "      print(f\"People centrism: {max_l4_precision:.2f}, {max_l4_recall:.2f}, {max_l4_f1:.2f}\")\n",
        "      print(f\"Micro: {micro_precision:.2f}, {micro_recall:.2f}, {micro_f1:.2f}\")\n",
        "      print(f\"Macro: {macro_precision:.2f}, {macro_recall:.2f}, {macro_f1:.2f}\")\n",
        "      print(\"======================\")\n",
        "      records[seed] = np.array([max_l1_precision, max_l1_recall, max_l1_f1, max_l2_precision, max_l2_recall, max_l2_f1, max_l3_precision, max_l3_recall, max_l3_f1, max_l4_precision, max_l4_recall, max_l4_f1, micro_precision, micro_recall, micro_f1, macro_precision, macro_recall, macro_f1])\n",
        "\n",
        "    # Free this run's model and optimizer before the next model_init loads a fresh copy.\n",
        "    del trainer\n",
        "    gc.collect()\n",
        "    torch.cuda.empty_cache()\n",
        "\n",
        "  # Make a missing seed visible instead of letting it silently drop out of the cross-seed mean.\n",
        "  if seed not in records:\n",
        "    print(f\"WARNING: no learning rate gave a finite dev macro F1 for seed {seed}; it is missing from records.\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "HNaQtJNxscVI",
        "outputId": "260a9a57-02bd-44ad-fda1-658234059865"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "11 [0.703125   0.703125   0.703125   0.66091954 0.7615894  0.70769231\n",
            " 0.77538071 0.93711656 0.84861111 0.67435159 0.74285714 0.70694864\n",
            " 0.7284345  0.82969432 0.77577407 0.70344421 0.78617203 0.74159426]\n",
            "12 [0.74712644 0.74712644 0.74712644 0.69035533 0.78612717 0.73513514\n",
            " 0.8592     0.84433962 0.851705   0.645      0.81904762 0.72167832\n",
            " 0.75927175 0.81299639 0.78521618 0.73542044 0.79916021 0.76391122]\n",
            "13 [0.74230769 0.70181818 0.72149533 0.7027027  0.64197531 0.67096774\n",
            " 0.82496413 0.89285714 0.85756898 0.65346535 0.77647059 0.70967742\n",
            " 0.75281643 0.79943702 0.77542662 0.73085997 0.75328031 0.73992737]\n",
            "[0.73085304 0.71735654 0.72391559 0.68465919 0.72989729 0.70459839\n",
            " 0.81984828 0.89143778 0.85262836 0.65760564 0.77945845 0.71276813\n",
            " 0.7468409  0.81404258 0.77880562 0.72324154 0.77953751 0.74847762]\n"
          ]
        }
      ],
      "source": [
        "# Test metrics of the selected run for each seed, followed by their element-wise mean over seeds.\n",
        "for seed, metrics in records.items():\n",
        "  print(seed, metrics)\n",
        "print(np.array(list(records.values())).mean(axis=0))"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "A100",
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "035d1e58ab2a4ff7bd80d74fc9581958": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "04e632bbd11b405ca6579c441538a446": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "0b7c9143f16546ccb0114dd6f5f05e1f": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_588d00688b8444f0a0dd4aa2ef08f84a",
              "IPY_MODEL_748922c942404e878bbe74c957c2b055",
              "IPY_MODEL_ca90952353d941d1a5efd123c4524fc3"
            ],
            "layout": "IPY_MODEL_035d1e58ab2a4ff7bd80d74fc9581958"
          }
        },
        "0cc219f470d847b8add4926ac2980b14": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "239573d80f7b45469cc47f5db5e52624": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "259407da05744d27ba945cab73d297c3": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "289357be8ae3454286754e55c8f40c3b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_501fc476aa9c46efa054d5357d29c254",
              "IPY_MODEL_5afedaf6e44c4612baf7cb9ddb0c2d69",
              "IPY_MODEL_e1aaa790ef8f4a66bdebca8daa6c029c"
            ],
            "layout": "IPY_MODEL_48c0a120298947c2a6e2082d16375a0c"
          }
        },
        "34c5abaeaca240ae9610a7c6c03e2992": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "3a76ddf6f4d7436b882b4c686a338f65": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "48c0a120298947c2a6e2082d16375a0c": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "498223edfee444d79fbfb5729cd7b1b5": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "4e030f30944f4da58ea23071874017af": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "501fc476aa9c46efa054d5357d29c254": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_cf9b78f2b4f64d98af4e66ad8e523757",
            "placeholder": "​",
            "style": "IPY_MODEL_259407da05744d27ba945cab73d297c3",
            "value": "vocab.txt: "
          }
        },
        "588d00688b8444f0a0dd4aa2ef08f84a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c87babf60681452e80d6c68b0fa644c3",
            "placeholder": "​",
            "style": "IPY_MODEL_60d65ed196de4d6782741c47e548c43d",
            "value": "tokenizer_config.json: 100%"
          }
        },
        "5afedaf6e44c4612baf7cb9ddb0c2d69": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_5d715a75665846f5a174e49f4ee20029",
            "max": 1,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_6e05804bb0694170bddb8bc075154e50",
            "value": 1
          }
        },
        "5d715a75665846f5a174e49f4ee20029": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "20px"
          }
        },
        "60d65ed196de4d6782741c47e548c43d": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "67d3b3e533d640beae9618ede0d92c61": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "6cebbc05a32843e1a0243addc246a3ea": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_7e2e00b07f2440c797c8dd4cde418016",
              "IPY_MODEL_710379c0119840b2919354ca0eb8168a",
              "IPY_MODEL_721134531c9d426cad45877d7760dedd"
            ],
            "layout": "IPY_MODEL_34c5abaeaca240ae9610a7c6c03e2992"
          }
        },
        "6e05804bb0694170bddb8bc075154e50": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "6f8f8354bf9e4d438b2b145b30bc31d5": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "710379c0119840b2919354ca0eb8168a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0cc219f470d847b8add4926ac2980b14",
            "max": 363,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_6f8f8354bf9e4d438b2b145b30bc31d5",
            "value": 363
          }
        },
        "721134531c9d426cad45877d7760dedd": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_239573d80f7b45469cc47f5db5e52624",
            "placeholder": "​",
            "style": "IPY_MODEL_a3837c636acb42628a558c83a90e6c92",
            "value": " 363/363 [00:00&lt;00:00, 49.6kB/s]"
          }
        },
        "748922c942404e878bbe74c957c2b055": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_04e632bbd11b405ca6579c441538a446",
            "max": 83,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_67d3b3e533d640beae9618ede0d92c61",
            "value": 83
          }
        },
        "7e2e00b07f2440c797c8dd4cde418016": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_498223edfee444d79fbfb5729cd7b1b5",
            "placeholder": "​",
            "style": "IPY_MODEL_4e030f30944f4da58ea23071874017af",
            "value": "config.json: 100%"
          }
        },
        "7e30e053f1894bd99fbfbff432ae00a7": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "94a05c4a4a3141d4b53736736105082a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "a3837c636acb42628a558c83a90e6c92": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "c87babf60681452e80d6c68b0fa644c3": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ca90952353d941d1a5efd123c4524fc3": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_7e30e053f1894bd99fbfbff432ae00a7",
            "placeholder": "​",
            "style": "IPY_MODEL_3a76ddf6f4d7436b882b4c686a338f65",
            "value": " 83.0/83.0 [00:00&lt;00:00, 10.1kB/s]"
          }
        },
        "cf9b78f2b4f64d98af4e66ad8e523757": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e1aaa790ef8f4a66bdebca8daa6c029c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_ebca5cc51c674a4e8d6363d4dca3e5fb",
            "placeholder": "​",
            "style": "IPY_MODEL_94a05c4a4a3141d4b53736736105082a",
            "value": " 240k/? [00:00&lt;00:00, 17.0MB/s]"
          }
        },
        "ebca5cc51c674a4e8d6363d4dca3e5fb": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}